package org.xqh.study.leetcode.algorithm.median;

import com.alibaba.fastjson.JSON;

import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import java.util.PriorityQueue;

/**
 * @ClassName MedianSlidingWindow
 * @Description 滑动窗口中位数
 * https://leetcode-cn.com/problems/sliding-window-median/
 * @Author xuqianghui
 * @Date 2021/2/3 14:42
 * @Version 1.0
 */
public class MedianSlidingWindow {

    /**
     * Small demo: offers a fixed set of integers to a priority queue ordered by
     * descending value and prints the string produced by
     * {@code JSON.toJSONString} for that queue.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        // Reverse of natural ordering, so the largest value sits at the head.
        PriorityQueue<Integer> maxHeap = new PriorityQueue<>(Comparator.<Integer>reverseOrder());

        int[] samples = {3, 5, 2, 11, 6, 7, 1};
        for (int sample : samples) {
            maxHeap.offer(sample);
        }

        System.out.println(JSON.toJSONString(maxHeap));
    }



    /**
     * Sliding-window median solver backed by {@link DualHeap}.
     */
    public static class Solution {
        /**
         * Computes the median of every contiguous window of {@code k} elements.
         *
         * @param nums input values; must not be {@code null}
         * @param k    window size; must satisfy {@code 1 <= k <= nums.length}
         * @return an array of length {@code nums.length - k + 1} whose entry
         *         {@code i} is the median of {@code nums[i .. i + k - 1]}
         * @throws IllegalArgumentException if {@code nums} is {@code null} or
         *                                  {@code k} is outside {@code [1, nums.length]}
         */
        public double[] medianSlidingWindow(int[] nums, int k) {
            // Fail fast with a descriptive error instead of an incidental
            // ArrayIndexOutOfBounds / NullPointerException further down.
            if (nums == null) {
                throw new IllegalArgumentException("nums must not be null");
            }
            if (k <= 0 || k > nums.length) {
                throw new IllegalArgumentException(
                        "k must be in [1, nums.length], got k=" + k + ", nums.length=" + nums.length);
            }

            DualHeap dh = new DualHeap(k);
            // Seed the structure with the first window.
            for (int i = 0; i < k; ++i) {
                dh.insert(nums[i]);
            }
            double[] ans = new double[nums.length - k + 1];
            ans[0] = dh.getMedian();
            // Slide: add the entering element, lazily remove the leaving one.
            for (int i = k; i < nums.length; ++i) {
                dh.insert(nums[i]);
                dh.erase(nums[i - k]);
                ans[i - k + 1] = dh.getMedian();
            }
            return ans;
        }
    }

    /**
     * Pair of heaps with lazy deletion, used to read the median of a window of
     * {@code k} values.
     *
     * <p>{@code small} is a max-heap for the lower half and {@code large} a
     * min-heap for the upper half. Erased values are not removed immediately;
     * they are counted in {@code delayed} and physically dropped once they
     * surface at the top of a heap.
     */
    public static class DualHeap {
        // Max-heap holding the smaller half of the elements.
        private final PriorityQueue<Integer> small;
        // Min-heap holding the larger half of the elements.
        private final PriorityQueue<Integer> large;
        // Lazy-deletion table: value -> number of pending removals of that value.
        private final Map<Integer, Integer> delayed;

        // Window size; only its parity is consulted when reading the median.
        private final int k;
        // Logical element counts of each heap, excluding lazily deleted entries.
        private int smallSize;
        private int largeSize;

        /**
         * @param k window size this structure will be used with
         */
        public DualHeap(int k) {
            this.k = k;
            this.small = new PriorityQueue<Integer>(Comparator.<Integer>reverseOrder());
            this.large = new PriorityQueue<Integer>(Comparator.<Integer>naturalOrder());
            this.delayed = new HashMap<Integer, Integer>();
            this.smallSize = 0;
            this.largeSize = 0;
        }

        /**
         * Returns the top of {@code small} when {@code k} is odd, otherwise the
         * average of both heap tops.
         */
        public double getMedian() {
            boolean oddWindow = (k & 1) == 1;
            if (oddWindow) {
                return small.peek();
            }
            // Cast before adding so the sum is computed in double arithmetic.
            return ((double) small.peek() + large.peek()) / 2;
        }

        /**
         * Adds {@code num} to the half it belongs to and rebalances.
         */
        public void insert(int num) {
            boolean belongsToLarge = !small.isEmpty() && num > small.peek();
            if (belongsToLarge) {
                large.offer(num);
                largeSize++;
            } else {
                small.offer(num);
                smallSize++;
            }
            makeBalance();
        }

        /**
         * Marks one occurrence of {@code num} as removed, drops it right away if
         * it is currently a heap top, and rebalances.
         */
        public void erase(int num) {
            int pending = delayed.getOrDefault(num, 0);
            delayed.put(num, pending + 1);

            if (num <= small.peek()) {
                smallSize--;
                if (num == small.peek()) {
                    prune(small);
                }
            } else {
                largeSize--;
                if (num == large.peek()) {
                    prune(large);
                }
            }
            makeBalance();
        }

        // Pops the top of heap while it is marked for deletion, updating the table.
        private void prune(PriorityQueue<Integer> heap) {
            while (!heap.isEmpty()) {
                int top = heap.peek();
                Integer pending = delayed.get(top);
                if (pending == null) {
                    // Top is a live element; nothing more to drop.
                    return;
                }
                int remaining = pending - 1;
                if (remaining == 0) {
                    delayed.remove(top);
                } else {
                    delayed.put(top, remaining);
                }
                heap.poll();
            }
        }

        // Moves one element between the heaps when their logical sizes drift apart.
        private void makeBalance() {
            if (smallSize > largeSize + 1) {
                // small has two more elements than large: hand its top over.
                Integer moved = small.poll();
                large.offer(moved);
                smallSize--;
                largeSize++;
                // The top of small changed, so it may now expose a deleted entry.
                prune(small);
                return;
            }
            if (smallSize < largeSize) {
                // large has one more element than small: hand its top over.
                Integer moved = large.poll();
                small.offer(moved);
                smallSize++;
                largeSize--;
                // The top of large changed, so it may now expose a deleted entry.
                prune(large);
            }
        }
    }

    /**
     * Ad-hoc driver that runs {@link #sortNumsIdx(int[])} on a fixed sample
     * array. The returned ranks are discarded.
     *
     * @param args unused
     */
    public static void main111(String[] args) {
        int[] samples = {3, 2, 5, -1, 6, 8, 1};
        sortNumsIdx(samples);
        // A JSON print of medianSlidingWindow(samples, 2) used to follow here; it is disabled.
    }

//    public static double[] medianSlidingWindow(int[] nums, int k) {
//        if(k > nums.length){
//            return null;
//        }
//        int len = nums.length - k + 1;//可以移动次数
//        double[] result = new double[len];
//        int[] sort = sortNumsIdx(nums);
//        int mid = (k + 1)/2;
//        //初始化 第0个位置
//        for(int i = 0; i < len; i ++){
//            double cur = 0;
//            if(k % 2 == 0){
//                //偶数 要取中间两位 的平均
//                cur = Double.valueOf(sorted[mid - 1])/2 + Double.valueOf(sorted[mid])/2;
//            }else {
//                cur = Double.valueOf(sorted[mid-1]);
//            }
//            result[i] = cur;
//        }
//        return result;
//    }
//
//    public static double getMedianNum(int[] sort, int start, int end, int[] nums){
//
//    }

    /**
     * Returns, for each element of {@code nums}, its position in the ascending
     * sorted order of the array.
     *
     * <p>Equal values are ranked by original index (the earlier one gets the
     * smaller rank), so the result is always a permutation of
     * {@code 0 .. nums.length - 1}. The input array is not modified.
     *
     * @param nums values to rank; must not be {@code null}
     * @return a new array where entry {@code i} is the sorted position of {@code nums[i]}
     */
    public static int[] sortNumsIdx(int[] nums){
        int n = nums.length;
        int[] sort = new int[n];
        for (int i = 0; i < n; i++) {
            // Rank = number of elements that must precede nums[i] in a stable sort.
            int rank = 0;
            for (int j = 0; j < n; j++) {
                boolean smaller = nums[j] < nums[i];
                boolean equalButEarlier = nums[j] == nums[i] && j < i;
                if (smaller || equalButEarlier) {
                    rank++;
                }
            }
            sort[i] = rank;
        }
        return sort;
    }

    /**
     * Returns a new array containing the elements of {@code nums} in ascending
     * order. The input array is left untouched.
     *
     * <p>Note from the original author: re-sorting the window on every slide
     * was too slow and timed out, which is why the heap-based approach exists.
     *
     * @param nums values to sort; must not be {@code null}
     * @return a sorted copy of {@code nums}
     */
    public static int[] sortList(int[] nums){
        // Copy first so the caller's array is not reordered.
        int[] sorted = nums.clone();
        Arrays.sort(sorted);
        return sorted;
    }

}
